import numpy as np
import os, pickle
from PIL import Image, ImageDraw, ImageFont
from viz_hand_obj import *


def calculate_center(bb):
    """Return the midpoint of a box as a two-element list.

    Args:
        bb: Indexable with at least four entries. Entries 0 and 2 are
            averaged for the first coordinate, entries 1 and 3 for the
            second (presumably an (x1, y1, x2, y2) box -- confirm with
            the detector output format).

    Returns:
        ``[center_0, center_1]`` as a plain Python list.
    """
    first_sum = bb[0] + bb[2]
    second_sum = bb[1] + bb[3]
    return [first_sum / 2, second_sum / 2]

def filter_object(obj_dets, hand_dets):
    """Match every hand detection to the index of its nearest object.

    For each in-contact hand, a target point is built by shifting the hand
    box center by ``hand_dets[i, 6] * 10000 * hand_dets[i, 7:9]``; the object
    whose box center has the smallest squared distance to that point wins.

    Args:
        obj_dets: 2-D numpy array, one row per object; columns 0-3 are the
            box corners used to compute the center.
        hand_dets: 2-D numpy array, one row per hand; columns 0-3 are the
            box corners, column 5 is the contact state (``<= 0`` is treated
            as non-contact) and columns 6-8 define the offset
            (presumably magnitude followed by a 2-D direction -- confirm
            against the detector that produced the array).

    Returns:
        A list with one ``int`` per hand row: the row index into
        ``obj_dets`` of the matched object, or ``-1`` when the hand is not
        in contact or when there are no objects to match against.
    """
    if obj_dets.shape[0] > 0:
        # Centers of all object boxes at once: (corner_a + corner_b) / 2.
        object_cc_list = (obj_dets[:, :2] + obj_dets[:, 2:4]) / 2
    else:
        # No candidates: an argmin over an empty set would raise, so every
        # hand is reported as unmatched instead.
        object_cc_list = None

    img_obj_id = []  # matching list, aligned with the rows of hand_dets
    for i in range(hand_dets.shape[0]):
        if object_cc_list is None or hand_dets[i, 5] <= 0:
            img_obj_id.append(-1)
            continue

        hand_cc = (hand_dets[i, :2] + hand_dets[i, 2:4]) / 2
        # Extended point: hand center pushed along the predicted offset.
        point_cc = hand_cc + hand_dets[i, 6] * 10000 * hand_dets[i, 7:9]
        dist = np.sum((object_cc_list - point_cc) ** 2, axis=1)
        img_obj_id.append(int(np.argmin(dist)))

    return img_obj_id


if __name__ == '__main__':

    ##############################################################
    # Save the detection results in a pickle file, in dictionary format like
    #
    # {"input_image_path_1": {
    #                         "hand_dets": hand_dets,
    #                         "obj_dets": obj_dets
    #                        },
    #  "input_image_path_2": {
    #                         "hand_dets": hand_dets,
    #                         "obj_dets": obj_dets
    #                        },
    #  ....
    # }
    #
    # With the img_obj_id list generated by filter_object(), hand_dets and
    # obj_dets can be matched: img_obj_id[i] is the row of obj_dets matched
    # to hand i, or -1 for a hand that is not in contact.
    ##############################################################

    pickle_path = 'path/to/saved/pickle/file'
    save_dir = 'path/to/save'
    os.makedirs(save_dir, exist_ok=True)
    thresh_hand = 0.5
    thresh_obj = 0.5
    viz = True
    max_dets = 10  # at most this many objects / hands are drawn per image

    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing -- only point pickle_path at files you produced yourself.
    with open(pickle_path, 'rb') as f:
        pickle_info = pickle.load(f)

    # The font does not depend on the image, so load it once up front.
    font = ImageFont.truetype('times_b.ttf', size=30) if viz else None

    for image_path, image_info in pickle_info.items():
        hand_dets = image_info['hand_dets']
        obj_dets = image_info['obj_dets']
        image_name = os.path.split(image_path)[-1]

        has_hands = hand_dets is not None
        # An empty object array is handled like "no objects": there is
        # nothing to match against and nothing to index into.
        has_objs = obj_dets is not None and obj_dets.shape[0] > 0

        # get matching list (only meaningful when both kinds are present)
        img_obj_id = filter_object(obj_dets, hand_dets) if (has_hands and has_objs) else None

        if not viz:
            continue

        # Close the source file right after decoding; convert() returns an
        # independent in-memory copy.
        with Image.open(image_path) as src:
            image = src.convert("RGBA")
        draw = ImageDraw.Draw(image)
        width, height = image.size

        # obj: drawn only when hands exist too, and only if matched to a hand
        if img_obj_id is not None:
            matched_ids = {int(k) for k in img_obj_id}
            for i in range(min(max_dets, obj_dets.shape[0])):
                obj_bbox = [int(np.round(x)) for x in obj_dets[i, :4]]
                obj_score = obj_dets[i, 4]

                # draw obj if > threshold and matched with a hand
                if obj_score > thresh_obj and i in matched_ids:
                    image = draw_obj_mask(image, draw, i, obj_bbox, obj_score, width, height, font)

        # hand
        if has_hands:
            for i in range(min(max_dets, hand_dets.shape[0])):
                hand_bbox = [int(np.round(x)) for x in hand_dets[i, :4]]
                hand_score = hand_dets[i, 4]
                hand_state = hand_dets[i, 5]
                hand_lr = hand_dets[i, -1]

                if hand_score <= thresh_hand:
                    continue

                image = draw_hand_mask(image, draw, i, hand_bbox, hand_score, hand_lr, hand_state, width, height, font)

                # in-contact hand with a matched object: connect the centers
                if img_obj_id is not None and hand_state > 0 and img_obj_id[i] >= 0:
                    matched_obj = obj_dets[img_obj_id[i], :4]
                    obj_cc, hand_cc = calculate_center(matched_obj), calculate_center(hand_bbox)
                    side_idx = int(hand_lr)
                    draw_line_point(draw, side_idx, (int(hand_cc[0]), int(hand_cc[1])), (int(obj_cc[0]), int(obj_cc[1])))

        # splitext strips whatever extension the input has (".jpg", ".jpeg",
        # none, ...) instead of blindly cutting the last four characters.
        save_path = os.path.join(save_dir, os.path.splitext(image_name)[0] + '_det.png')
        image.save(save_path)
                                        

        
